
import functools
import inspect

from typing import List, Optional

from wintersweet.asyncs.utils.futures import RunAbleFuture
from wintersweet.framework.fastapi.limiter.limit import KeyRateLimit, LimitGroup, APIRateLimit, BaseLimit
from wintersweet.framework.fastapi.limiter.util import get_remote_addr


class APILimiter:
    """Registry of endpoint limiters, exposed as decorators.

    ``limit`` attaches a per-client frequency limit (a ``KeyRateLimit`` plus any
    extra limits, grouped in a ``LimitGroup``). ``api_limit`` attaches an
    ``APIRateLimit`` running/waiting pool. Both are stored by limit key with
    ``dict.setdefault``, so endpoints decorated with the same key share the
    limiter that was registered first under that key.
    """

    # Parameter names recognised as the incoming connection object.
    _CONNECTION_PARAM_NAMES = ("request", "websocket")

    def __init__(self, key=None):
        """
        :param key: default callable used by ``limit`` to derive a client's
                    unique tag; falls back to ``get_remote_addr`` when not given
        """
        self._key = key or get_remote_addr
        # limit_key -> LimitGroup registered by ``limit``
        self._limits = {}
        # limit_key -> APIRateLimit registered by ``api_limit``
        self._api_limits = {}

    @classmethod
    def _locate_connection_param(cls, func):
        """Return ``(position, name)`` of *func*'s ``request``/``websocket`` parameter.

        :raises TypeError: if *func* has no parameter with one of those names
        """
        parameters = inspect.signature(func).parameters.values()
        for position, parameter in enumerate(parameters):
            if parameter.name in cls._CONNECTION_PARAM_NAMES:
                return position, parameter.name
        raise TypeError(
            f'No "request" or "websocket" argument on function "{func}"'
        )

    def limit(
            self,
            times=5,
            seconds=60,
            forbid_seconds=0,
            key=None,
            max_tags_size=None,
            limit_key: Optional[str] = None,
            redis=None,
            ntp=None,
            forbid_response_factory=None,
            limits: Optional[List[BaseLimit]] = None
    ):
        """
        Control how often each client may call the decorated endpoint.

        :param times: Number of visits allowed within ``seconds``
        :param seconds: Length, in seconds, of the time window being counted
        :param forbid_seconds: If a client makes ``times`` visits within ``seconds`` seconds,
                               it is forbidden from access for ``forbid_seconds`` seconds
        :param key: a function that returns the client's unique tag; defaults to
                    the ``key`` given to ``APILimiter``
        :param max_tags_size: The max number of client tags to record
        :param limit_key: identifier of this limit; defaults to
                          ``"<module>.<function name>"``. Endpoints sharing a key share
                          the first limit group registered under it
        :param redis: can give a redis_pool name
        :param ntp: ntp server, like an 'IntervalNTPClientManager' or 'IntervalNTPClient' object
        :param forbid_response_factory: Return the object of this class if the user has been
                                        denied access (passed through to ``KeyRateLimit``)
        :param limits: Additional limits checked together with the key rate limit;
                       you can expand it
        :raises TypeError: at decoration time, if the decorated function has no
                           ``request`` or ``websocket`` parameter
        """

        def decorator(func):
            _limit_key = limit_key or f'{func.__module__}.{func.__name__}'

            # Validate first, so a rejected function leaves no limit group
            # behind in ``self._limits``.
            position, param_name = self._locate_connection_param(func)

            _limits = [
                KeyRateLimit(
                    times=int(times),
                    seconds=int(seconds),
                    forbid_seconds=int(forbid_seconds),
                    key=key or self._key,
                    name=_limit_key,
                    max_tags_size=max_tags_size,
                    redis=redis,
                    ntp=ntp,
                    forbid_response_factory=forbid_response_factory
                )
            ]
            if limits:
                _limits.extend(limits)

            # setdefault: the first group registered under a key wins.
            self._limits.setdefault(_limit_key, LimitGroup(_limits))

            @functools.wraps(func)
            async def wrapper(*args, **kwargs):
                # Prefer the keyword under the parameter's real name (this covers
                # ``websocket`` endpoints), then fall back to the positional slot.
                if param_name in kwargs:
                    request = kwargs[param_name]
                else:
                    request = args[position]

                response = await self._limits[_limit_key].check_request_limited(request)
                if response is False:
                    return await func(*args, **kwargs)
                return response

            return wrapper

        return decorator

    def api_limit(
            self,
            running_size=10,
            waiting_size=10,
            limit_key=None,
            forbid_response_factory=None,
            task_factory=RunAbleFuture
    ):
        """
        Limit an API's throughput with a running pool and a waiting pool.
        Requests that arrive when both pools are full are rejected. As requests
        in the running pool complete, requests in the waiting pool continue to run.

        :param running_size: size of the running pool; at most ``running_size``
                             requests are processed at the same time
        :param waiting_size: size of the waiting pool; at most ``waiting_size``
                             requests may wait to be processed
        :param limit_key: unique identifier of the API. Several APIs can share one
                          rate limiter by passing the same key; the limiter that is
                          registered first under that key is the one that is used
        :param forbid_response_factory: response factory class
        :param task_factory: factory class for awaitable tasks
        :return: a decorator for async endpoint functions
        """
        # Created once per ``api_limit`` call: every function decorated by the
        # returned decorator object shares this instance, even under different keys.
        rate_limit = APIRateLimit(running_size, waiting_size, forbid_response_factory, task_factory)

        def decorator(func):
            _limit_key = limit_key or f'{func.__module__}.{func.__name__}'
            # setdefault: the first limiter registered under a key wins.
            self._api_limits.setdefault(_limit_key, rate_limit)

            @functools.wraps(func)
            async def wrapper(*args, **kwargs):
                limiter = self._api_limits[_limit_key]
                f = limiter.append(func, args, kwargs)
                if f is None:
                    # presumably both pools are full and the call was rejected
                    return limiter.forbid_response_factory('Service Busy')
                return await f

            return wrapper

        return decorator
